In [ ]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils import spectral_norm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
import glob
import matplotlib.image as mpimg
In [ ]:
# --------------------------------------
# Hyperparameters (Optimized for SNGAN)
# --------------------------------------
EPOCHS = 550             # total training epochs (training resumes from a checkpoint if one exists)
BATCH_SIZE = 128
IMAGE_SIZE = 32          # CIFAR-10 native resolution
CHANNELS_IMG = 3         # RGB
LATENT_DIM = 128         # generator noise-vector size
EMBED_DIM = 100          # label-embedding size (single "automobile" class)
LEARNING_RATE = 2e-4
BETA1, BETA2 = 0.0, 0.9  # Adam betas commonly used with spectral-norm GANs (beta1 = 0)
CHECKPOINT_EVERY = 20    # save / sample / metric cadence, in epochs
In [ ]:
# CIFAR-10 class index for "automobile"
AUTOMOBILE_CLASS_IDX = 1
# --------------------------------------
# Data Loading (CIFAR-10) - AUTOMOBILE ONLY
# --------------------------------------
# Map pixels from [0, 1] to [-1, 1] so real images match the generator's Tanh range.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
full_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
# Filter to automobile images only.  Reading labels via the dataset's
# `targets` list avoids decoding + transforming every one of the 50k images
# just to inspect its label (the original `enumerate(full_dataset)` did
# exactly that once at startup).
automobile_indices = [
    i for i, label in enumerate(full_dataset.targets)
    if label == AUTOMOBILE_CLASS_IDX
]
automobile_dataset = Subset(full_dataset, automobile_indices)
In [ ]:
# Create dataloader with only automobile images
trainloader = DataLoader(
    automobile_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2
)
# Prefer Apple-Silicon GPU, then CUDA, then CPU.
# Fix: `torch.backends.mps.is_available()` is the documented, stable
# availability check; `torch.mps.is_available()` does not exist in older
# PyTorch releases and raises AttributeError there.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f'Using device {device}')
Using device mps
In [ ]:
# --------------------------------------
# SNGAN Generator with Improved Conditioning
# --------------------------------------
class Generator(nn.Module):
    """DCGAN-style conditional generator producing 32x32 RGB images in [-1, 1].

    The label-conditioning path is kept for interface compatibility; since the
    model is trained on a single class (automobile), the label defaults to 0.
    """

    def __init__(self, latent_dim, embed_dim):
        super().__init__()
        # Single-class embedding table (automobile only).
        self.label_emb = nn.Embedding(1, embed_dim)
        # Project noise + label embedding up to a 512-channel 4x4 feature map.
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + embed_dim, 4 * 4 * 512),
            nn.BatchNorm1d(4 * 4 * 512),
            nn.ReLU(True),
        )
        # Three transposed convolutions, each doubling spatial size:
        # 4x4 -> 8x8 -> 16x16 -> 32x32, ending in Tanh for [-1, 1] pixels.
        up_layers = [
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, CHANNELS_IMG, 4, 2, 1),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*up_layers)

    def forward(self, z, labels=None):
        batch = z.size(0)
        # Single-class setup: default every sample to class 0.
        if labels is None:
            labels = torch.zeros(batch, dtype=torch.long, device=z.device)
        joint = torch.cat([z, self.label_emb(labels)], dim=1)
        feat = self.fc(joint).view(batch, 512, 4, 4)
        return self.deconv(feat)
In [ ]:
# --------------------------------------
# SNGAN Discriminator (Spectral Norm) with Improved Conditioning
# --------------------------------------
class Discriminator(nn.Module):
    """Spectrally-normalized convolutional discriminator for 32x32 inputs.

    The label embedding is concatenated with the conv features before the
    final linear layer; with a single class it acts as a learned constant.
    Returns a raw real/fake logit (no sigmoid) for BCEWithLogitsLoss.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Single-class embedding table (automobile only).
        self.label_emb = nn.Embedding(1, embed_dim)
        # Strided conv stack halves resolution four times: 32 -> 16 -> 8 -> 4 -> 2.
        # Spectral norm on every conv enforces the SNGAN Lipschitz constraint.
        conv_layers = [
            spectral_norm(nn.Conv2d(CHANNELS_IMG, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*conv_layers)
        # Single logit from flattened 512x2x2 features plus the label embedding.
        self.fc = spectral_norm(nn.Linear(512 * 2 * 2 + embed_dim, 1))

    def forward(self, x, labels=None):
        batch = x.size(0)
        # Single-class setup: default every sample to class 0.
        if labels is None:
            labels = torch.zeros(batch, dtype=torch.long, device=x.device)
        feats = self.conv(x).view(batch, -1)
        joint = torch.cat([feats, self.label_emb(labels)], dim=1)
        return self.fc(joint)
In [ ]:
# --------------------------------------
# Initialize Models & Optimizers
# --------------------------------------
gen = Generator(LATENT_DIM, EMBED_DIM).to(device)
disc = Discriminator(EMBED_DIM).to(device)
# BCE on raw logits -- the discriminator has no final sigmoid.
criterion = nn.BCEWithLogitsLoss()
# Separate Adam optimizers; beta1 = 0.0 per the SNGAN hyperparameters above.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(BETA1, BETA2))
In [ ]:
# Checkpoint loading
start_epoch = 1
checkpoint_path = "adl_part2.pt"
# --------------------------------------
# Check for Existing Checkpoint
# --------------------------------------
# Resume both networks AND their optimizer states so Adam moments survive a
# restart; training continues from the epoch after the saved one.
# NOTE(review): the checkpoint also stores "g_losses"/"d_losses" but they are
# not restored here, so the loss history restarts empty after a resume.
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints you created yourself.
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    gen.load_state_dict(checkpoint["gen_state_dict"])
    disc.load_state_dict(checkpoint["disc_state_dict"])
    opt_gen.load_state_dict(checkpoint["opt_gen_state_dict"])
    opt_disc.load_state_dict(checkpoint["opt_disc_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Loaded checkpoint from epoch {start_epoch-1}")
Loaded checkpoint from epoch 500
In [ ]:
# --------------------------------------
# Utility: Generate & Show 10 Samples
# --------------------------------------
def generate_and_show_samples(epoch):
    """Draw 10 generator samples, display them in a row, and save to a PNG."""
    gen.eval()
    with torch.no_grad():
        noise = torch.randn(10, LATENT_DIM, device=device)
        imgs = gen(noise).cpu()
    # Undo the [-1, 1] normalization back to [0, 1] for display.
    imgs = (imgs + 1) / 2.0
    fig, axes = plt.subplots(1, 10, figsize=(22, 2.4))
    for ax, img in zip(axes, imgs):
        ax.imshow(img.permute(1, 2, 0).numpy())
        ax.axis('off')
    plt.suptitle(f"Epoch {epoch}: Generated Automobiles", fontsize=14)
    plt.savefig(f"automobile_samples_epoch_{epoch}.png")
    plt.show()
    gen.train()
In [ ]:
# --------------------------------------
# Compute IS & FID with TorchMetrics
# --------------------------------------
def compute_is_fid(generator, loader, n_samples=2000):
    """Estimate Inception Score and FID for `generator`.

    Feeds exactly `n_samples` real images from `loader` and `n_samples`
    generated images through torchmetrics' Inception-based metrics
    (computed on CPU).

    Args:
        generator: module mapping (B, LATENT_DIM) noise to images in [-1, 1].
        loader: DataLoader yielding (images, labels) with images in [-1, 1].
        n_samples: number of real and of fake images to evaluate.

    Returns:
        (inception_score_mean, fid) as Python floats.
    """
    is_metric = InceptionScore().to("cpu")
    fid_metric = FrechetInceptionDistance().to("cpu")
    generator.eval()
    # Real branch: keep images on CPU -- the metrics run on CPU, so the
    # original round-trip through `device` was wasted transfer.
    real_count = 0
    for real_imgs, _ in loader:
        # Fix: trim the final batch so the real count matches n_samples
        # exactly instead of overshooting by up to one batch.
        if real_count + real_imgs.size(0) > n_samples:
            real_imgs = real_imgs[: n_samples - real_count]
        # Map normalized [-1, 1] pixels back to uint8 [0, 255], as the
        # torchmetrics Inception wrappers expect.
        real_uint8 = (((real_imgs * 0.5) + 0.5) * 255).to(torch.uint8)
        fid_metric.update(real_uint8, real=True)
        real_count += real_imgs.size(0)
        if real_count >= n_samples:
            break
    # Fake branch: generate exactly n_samples images in BATCH_SIZE chunks.
    fake_count = 0
    with torch.no_grad():
        while fake_count < n_samples:
            z = torch.randn(min(BATCH_SIZE, n_samples - fake_count), LATENT_DIM, device=device)
            fake_out = generator(z)
            fake_uint8 = (((fake_out * 0.5) + 0.5) * 255).to(torch.uint8).cpu()
            is_metric.update(fake_uint8)
            fid_metric.update(fake_uint8, real=False)
            fake_count += z.size(0)
    inception_score = is_metric.compute()  # (mean, std)
    fid_score = fid_metric.compute()
    generator.train()
    return inception_score[0].item(), fid_score.item()
In [ ]:
# --------------------------------------
# Training
# --------------------------------------
# Standard alternating GAN training: one discriminator step then one
# generator step per minibatch, with BCE-with-logits targets (1 = real).
g_losses, d_losses = [], []
for epoch in range(start_epoch, EPOCHS + 1):
    epoch_g_losses, epoch_d_losses = [], []
    for _, (real, _) in enumerate(trainloader):
        real = real.to(device)
        bsz = real.size(0)
        # ---- Train Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ----
        disc.zero_grad()
        noise = torch.randn(bsz, LATENT_DIM, device=device)
        # Real images (all are automobiles)
        pred_real = disc(real)
        loss_real = criterion(pred_real, torch.ones_like(pred_real))
        # Fake images; detach so D's backward pass leaves G's graph intact
        # for the generator step below.
        fake = gen(noise)
        pred_fake = disc(fake.detach())
        loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
        lossD = loss_real + loss_fake
        lossD.backward()
        opt_disc.step()
        epoch_d_losses.append(lossD.item())
        # ---- Train Generator: label the same fakes as real (non-saturating
        # loss), re-scored by the just-updated discriminator ----
        gen.zero_grad()
        pred_gen = disc(fake)
        lossG = criterion(pred_gen, torch.ones_like(pred_gen))
        lossG.backward()
        opt_gen.step()
        epoch_g_losses.append(lossG.item())
    # Calculate average losses for the epoch
    avg_g_loss = sum(epoch_g_losses) / len(epoch_g_losses)
    avg_d_loss = sum(epoch_d_losses) / len(epoch_d_losses)
    g_losses.append(avg_g_loss)
    d_losses.append(avg_d_loss)
    print(f"[Epoch {epoch}/{EPOCHS}] LossD: {avg_d_loss:.4f} LossG: {avg_g_loss:.4f}")
    # Save, visualize, compute IS/FID at checkpoints
    if epoch % CHECKPOINT_EVERY == 0:
        data_to_save = {
            "epoch": epoch,
            "gen_state_dict": gen.state_dict(),
            "disc_state_dict": disc.state_dict(),
            "opt_gen_state_dict": opt_gen.state_dict(),
            "opt_disc_state_dict": opt_disc.state_dict(),
            "g_losses": g_losses,
            "d_losses": d_losses
        }
        torch.save(data_to_save, checkpoint_path)
        print(f"[epoch={epoch}] Checkpoint saved: {checkpoint_path}")
        generate_and_show_samples(epoch)
        # Compute metrics -- best effort: a metric failure must not kill a
        # multi-hour training run.
        try:
            is_val, fid_val = compute_is_fid(gen, trainloader)
            print(f"==> Epoch {epoch}: Inception Score = {is_val:.4f}, FID = {fid_val:.4f}")
        except Exception as e:
            print(f"Error computing metrics: {e}")
print("Training complete!")
[Epoch 501/550] LossD: 1.2422 LossG: 0.8610 [Epoch 502/550] LossD: 1.2375 LossG: 0.8608 [Epoch 503/550] LossD: 1.2366 LossG: 0.8723 [Epoch 504/550] LossD: 1.2385 LossG: 0.8631 [Epoch 505/550] LossD: 1.2430 LossG: 0.8634 [Epoch 506/550] LossD: 1.2413 LossG: 0.8667 [Epoch 507/550] LossD: 1.2389 LossG: 0.8611 [Epoch 508/550] LossD: 1.2348 LossG: 0.8667 [Epoch 509/550] LossD: 1.2404 LossG: 0.8688 [Epoch 510/550] LossD: 1.2421 LossG: 0.8608 [Epoch 511/550] LossD: 1.2417 LossG: 0.8590 [Epoch 512/550] LossD: 1.2413 LossG: 0.8607 [Epoch 513/550] LossD: 1.2391 LossG: 0.8620 [Epoch 514/550] LossD: 1.2373 LossG: 0.8610 [Epoch 515/550] LossD: 1.2384 LossG: 0.8672 [Epoch 516/550] LossD: 1.2381 LossG: 0.8715 [Epoch 517/550] LossD: 1.2483 LossG: 0.8597 [Epoch 518/550] LossD: 1.2396 LossG: 0.8585 [Epoch 519/550] LossD: 1.2393 LossG: 0.8665 [Epoch 520/550] LossD: 1.2391 LossG: 0.8650 [epoch=520] Checkpoint saved: adl_part2.pt
/Users/shivamsahil/Downloads/bits/assignments/venv/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `InceptionScore` will save all extracted features in buffer. For large datasets this may lead to large memory footprint. warnings.warn(*args, **kwargs) # noqa: B028
==> Epoch 520: Inception Score = 3.0748, FID = 83.1932 [Epoch 521/550] LossD: 1.2403 LossG: 0.8559 [Epoch 522/550] LossD: 1.2426 LossG: 0.8677 [Epoch 523/550] LossD: 1.2364 LossG: 0.8648 [Epoch 524/550] LossD: 1.2446 LossG: 0.8790 [Epoch 525/550] LossD: 1.2414 LossG: 0.8541 [Epoch 526/550] LossD: 1.2367 LossG: 0.8642 [Epoch 527/550] LossD: 1.2351 LossG: 0.8661 [Epoch 528/550] LossD: 1.2436 LossG: 0.8595 [Epoch 529/550] LossD: 1.2361 LossG: 0.8591 [Epoch 530/550] LossD: 1.2425 LossG: 0.8593 [Epoch 531/550] LossD: 1.2403 LossG: 0.8597 [Epoch 532/550] LossD: 1.2382 LossG: 0.8658 [Epoch 533/550] LossD: 1.2399 LossG: 0.8597 [Epoch 534/550] LossD: 1.2350 LossG: 0.8632 [Epoch 535/550] LossD: 1.2372 LossG: 0.8640 [Epoch 536/550] LossD: 1.2392 LossG: 0.8557 [Epoch 537/550] LossD: 1.2438 LossG: 0.8636 [Epoch 538/550] LossD: 1.2390 LossG: 0.8611 [Epoch 539/550] LossD: 1.2391 LossG: 0.8711 [Epoch 540/550] LossD: 1.2337 LossG: 0.8539 [epoch=540] Checkpoint saved: adl_part2.pt
==> Epoch 540: Inception Score = 3.1830, FID = 75.2091 [Epoch 541/550] LossD: 1.2420 LossG: 0.8673 [Epoch 542/550] LossD: 1.2424 LossG: 0.8556 [Epoch 543/550] LossD: 1.2356 LossG: 0.8577 [Epoch 544/550] LossD: 1.2357 LossG: 0.8613 [Epoch 545/550] LossD: 1.2385 LossG: 0.8711 [Epoch 546/550] LossD: 1.2392 LossG: 0.8536 [Epoch 547/550] LossD: 1.2412 LossG: 0.8608 [Epoch 548/550] LossD: 1.2398 LossG: 0.8599 [Epoch 549/550] LossD: 1.2420 LossG: 0.8645 [Epoch 550/550] LossD: 1.2401 LossG: 0.8617 Training complete!
Display all Results at checkpoints
In [ ]:
directory = r'task2'
# Sort key: pull the integer epoch out of "automobile_samples_epoch_{N}.png".
def extract_epoch(filename):
    """Return the epoch number encoded in a sample-image filename.

    Filenames that do not match the expected pattern sort last (inf).
    """
    base = os.path.basename(filename)
    try:
        return int(base.split('automobile_samples_epoch_')[1].split('.')[0])
    except (IndexError, ValueError):
        return float('inf')
# Collect every saved sample grid and order it numerically by epoch.
png_files = sorted(glob.glob(os.path.join(directory, '*.png')), key=extract_epoch)
if not png_files:
    print("No PNG files found in the directory:", directory)
else:
    n = len(png_files)
    # One row per checkpoint image; scale the figure height with the count.
    fig, axs = plt.subplots(n, 1, figsize=(22, 2.4 * n))
    if n == 1:
        axs = [axs]  # subplots() returns a bare Axes when n == 1
    # Best-effort window maximization -- backend-specific, failures are fine.
    mng = plt.get_current_fig_manager()
    try:
        mng.window.state('zoomed')
    except AttributeError:
        try:
            mng.window.showMaximized()
        except Exception:
            pass  # If it fails, the figure will remain at the set figsize
    # Render each checkpoint image on its own row, titled with its filename.
    for ax, file in zip(axs, png_files):
        ax.imshow(mpimg.imread(file), aspect='auto')
        ax.axis('off')
        ax.set_title(os.path.basename(file), fontsize=14)
    plt.tight_layout()
    plt.show()
In [ ]:
# Google-Colab-only cell: export this notebook to HTML and download it.
# Install necessary packages
# NOTE(review): the TeX packages are only needed for PDF export; the actual
# conversion below targets HTML, so texlive/pandoc may be unnecessary.
!apt-get install texlive texlive-xetex texlive-latex-extra pandoc
!pip install pypandoc
# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive", force_remount=True)
# Copy the notebook to the current directory
!cp 'drive/My Drive/Colab Notebooks/Assignment2_Group75_Task2.ipynb' ./
# Convert the notebook to HTML while keeping the code and output
!jupyter nbconvert --to html "Assignment2_Group75_Task2.ipynb"
# Download the generated HTML file
from google.colab import files
files.download('Assignment2_Group75_Task2.html')
done. Collecting pypandoc Downloading pypandoc-1.15-py3-none-any.whl.metadata (16 kB) Downloading pypandoc-1.15-py3-none-any.whl (21 kB) Installing collected packages: pypandoc Successfully installed pypandoc-1.15